from matplotlib.testing.decorators import image_comparison as img_comp

from pyvrp import (
    CostEvaluator,
    Population,
    PopulationParams,
    RandomNumberGenerator,
    Result,
    Route,
    Solution,
    Statistics,
    plotting,
)
from pyvrp.diversity import broken_pairs_distance
from tests.helpers import read, read_solution

# Keyword arguments shared by every ``image_comparison`` decorator below:
# strip text from the figures before comparing, allow an RMS difference of up
# to 8, only compare PNG output, and render with the "mpl20" style.
IMG_KWARGS = {
    "remove_text": True,
    "tol": 8,
    "extensions": ["png"],
    "style": "mpl20",
}


@img_comp(
    baseline_images=["plot_solution", "plot_solution_with_customers"],
    **IMG_KWARGS,
)
def test_plot_solution():
    """
    Plots the best known RC208 solution twice with ``plot_solution``: once
    with default arguments and once with ``plot_clients=True``. The resulting
    figures are compared against the two baseline images listed above.
    """
    instance = read("data/RC208.vrp", round_func="trunc")
    solution = Solution(instance, read_solution("data/RC208.sol"))

    # NOTE(review): assumes each plot_solution call opens its own figure;
    # image_comparison matches figures to baseline names in creation order.
    plotting.plot_solution(solution, instance)
    plotting.plot_solution(solution, instance, plot_clients=True)


@img_comp(baseline_images=["plot_solution_multiple_depots"], **IMG_KWARGS)
def test_plot_solution_multiple_depots():
    """
    Checks that ``plot_solution`` renders a solution on an instance with more
    than one depot: the depots themselves should appear, and each route's
    lines should start and end at the right depot.
    """
    instance = read("data/OkSmallMultipleDepots.txt", round_func="trunc")

    # One route per vehicle type, so both depots are used.
    first = Route(instance, [2, 3], vehicle_type=0)
    second = Route(instance, [4], vehicle_type=1)

    solution = Solution(instance, [first, second])
    plotting.plot_solution(solution, instance)


@img_comp(baseline_images=["plot_solution_optional_clients"], **IMG_KWARGS)
def test_plot_solution_optional_clients(ok_small_prizes):
    """
    Checks that ``plot_solution`` with ``plot_clients=True`` also draws the
    locations that the solution leaves unvisited, on an instance that has
    optional clients (the ``ok_small_prizes`` fixture).
    """
    data = ok_small_prizes

    # Two routes, both using vehicle type 0: [2, 3] and [1].
    routes = [Route(data, visits, vehicle_type=0) for visits in ([2, 3], [1])]

    sol = Solution(data, routes)
    plotting.plot_solution(sol, data, plot_clients=True)


@img_comp(baseline_images=["plot_result"], **IMG_KWARGS)
def test_plot_result():
    """
    Builds a ``Statistics`` history by repeatedly adding random solutions
    (plus the best known solution once, halfway through) to a population.
    It then compares the figure produced by ``plot_result`` against the
    baseline image.
    """
    num_iterations = 100
    halfway = num_iterations // 2

    instance = read("data/RC208.vrp", round_func="trunc")
    bks = read_solution("data/RC208.sol")
    cost_eval = CostEvaluator(20, 6, 0)
    rng = RandomNumberGenerator(seed=42)

    # Seed the population with min_pop_size random solutions.
    pop_params = PopulationParams()
    population = Population(broken_pairs_distance, params=pop_params)
    for _ in range(pop_params.min_pop_size):
        population.add(Solution.make_random(instance, rng), cost_eval)

    stats = Statistics()
    for iteration in range(num_iterations):
        # Exactly once, at the halfway point, add the best known solution
        # instead of a random one, so that a feasible solution is inserted.
        if iteration == halfway:
            candidate = Solution(instance, bks)
        else:
            candidate = Solution.make_random(instance, rng)

        population.add(candidate, cost_eval)
        stats.collect_from(population, cost_eval)

        # Hack: replace the just-recorded runtime with a fixed 0, 1, 2, ...
        # pattern so the plotted result is deterministic (the runtimes are
        # presumably wall-clock measurements otherwise).
        stats.runtimes[-1] = iteration % 3

    result = Result(Solution(instance, bks), stats, num_iterations, 0.0)
    plotting.plot_result(result, instance)


@img_comp(baseline_images=["plot_instance"], **IMG_KWARGS)
def test_plot_instance():
    """
    Renders the RC208 instance with ``plot_instance`` and compares the figure
    against the stored baseline image.
    """
    plotting.plot_instance(read("data/RC208.vrp", round_func="trunc"))


@img_comp(baseline_images=["plot_instance_multiple_depots"], **IMG_KWARGS)
def test_plot_instance_multiple_depots():
    """
    Checks that ``plot_instance`` renders an instance that has more than one
    depot, by comparing the figure against the stored baseline image.
    """
    instance_path = "data/OkSmallMultipleDepots.txt"
    plotting.plot_instance(read(instance_path, round_func="trunc"))


@img_comp(baseline_images=["plot_route_schedule"], **IMG_KWARGS)
def test_plot_route_schedule():
    """
    Plots the schedule of the first route of the best known RC208 solution
    with ``plot_route_schedule`` and compares it against the baseline image.
    """
    instance = read("data/RC208.vrp", round_func="trunc")
    solution = Solution(instance, read_solution("data/RC208.sol"))

    first_route = solution.routes()[0]
    plotting.plot_route_schedule(instance, first_route)
